{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **Homework 5 - Sequence-to-sequence**\n",
    "\n",
    "If you have any questions, feel free to email us at: ntu-ml-2021spring-ta@googlegroups.com"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (4/21 Updates)\n",
    "1. Link to reference [training curves](https://wandb.ai/george0828zhang/hw5.seq2seq.new).\n",
    "\n",
    "### (4/14 Updates) \n",
    "1. Link to tutorial video [part 1](https://youtu.be/1pjS5_L5REI) [part 2](https://youtu.be/3XX9d0ymKgQ).\n",
    "2. Now defaults to load `\"avg_last_5_checkpoint.pt\"` to generate prediction.\n",
    "3. Expected run time on Colab with Tesla T4\n",
    "\n",
    "|Baseline|Details|Total Time|\n",
    "|-|:-:|:-:|\n",
    "|Simple|2m 15s $\\times$30 epochs|1hr 8m|\n",
    "|Medium|4m $\\times$30 epochs|2hr|\n",
    "|Strong|8m $\\times$30 epochs (backward)<br>+1hr (back-translation)<br>+15m $\\times$30 epochs (forward)|12hr 30m|"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sequence-to-Sequence Introduction\n",
    "- Typical sequence-to-sequence (seq2seq) models are encoder-decoder models, which usually consists of two parts, the encoder and decoder, respectively. These two parts can be implemented with recurrent neural network (RNN) or transformer, primarily to deal with input/output sequences of dynamic length.\n",
    "- **Encoder** encodes a sequence of inputs, such as text, video or audio, into a single vector, which can be viewed as the abstractive representation of the inputs, containing information of the whole sequence.\n",
    "- **Decoder** decodes the vector output of encoder one step at a time, until the final output sequence is complete. Every decoding step is affected by previous step(s). Generally, one would add \"< BOS >\" at the begining of the sequence to indicate start of decoding, and \"< EOS >\" at the end to indicate end of decoding.\n",
    "\n",
    "![seq2seq](https://i.imgur.com/0zeDyuI.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Homework Description\n",
    "- English to Chinese (Traditional) Translation\n",
    "  - Input: an English sentence         (e.g.\t\ttom is a student .)\n",
    "  - Output: the Chinese translation  (e.g. \t\t湯姆 是 個 學生 。)\n",
    "\n",
    "- TODO\n",
    "    - Train a simple RNN seq2seq to acheive translation\n",
    "    - Switch to transformer model to boost performance\n",
    "    - Apply Back-translation to furthur boost performance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download and import required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install 'torch>=1.6.0' editdistance matplotlib sacrebleu sacremoses sentencepiece tqdm wandb\n",
    "!pip install --upgrade jupyter ipywidgets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/pytorch/fairseq.git\n",
    "!cd fairseq && git checkout 9a1c497\n",
    "!pip install --upgrade ./fairseq/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import pdb\n",
    "import pprint\n",
    "import logging\n",
    "import os\n",
    "import random\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils import data\n",
    "import numpy as np\n",
    "import tqdm.auto as tqdm\n",
    "from pathlib import Path\n",
    "from argparse import Namespace\n",
    "from fairseq import utils\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fix random seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 73\n",
    "random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)  \n",
    "np.random.seed(seed)  \n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset Information\n",
    "\n",
    "## En-Zh Bilingual Parallel Corpus\n",
    "* [TED2020](#reimers-2020-multilingual-sentence-bert)\n",
    "    - Raw: 398,066 (sentences)   \n",
    "    - Processed: 393,980 (sentences)\n",
    "    \n",
    "\n",
    "## Testdata\n",
    "- Size: 4,000 (sentences)\n",
    "- **Chinese translation is undisclosed. The provided (.zh) file is psuedo translation, each line is a '。'**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset Download"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install megatools (optional)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!apt-get install megatools"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download and extract"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = './DATA/rawdata'\n",
    "dataset_name = 'ted2020'\n",
    "urls = (\n",
    "    '\"https://onedrive.live.com/download?cid=3E549F3B24B238B4&resid=3E549F3B24B238B4%214989&authkey=AGgQ-DaR8eFSl1A\"', \n",
    "    '\"https://onedrive.live.com/download?cid=3E549F3B24B238B4&resid=3E549F3B24B238B4%214987&authkey=AA4qP_azsicwZZM\"',\n",
    "# # If the above links die, use the following instead. \n",
    "#     \"https://www.csie.ntu.edu.tw/~r09922057/ML2021-hw5/ted2020.tgz\",\n",
    "#     \"https://www.csie.ntu.edu.tw/~r09922057/ML2021-hw5/test.tgz\",\n",
    "# # If the above links die, use the following instead. \n",
    "#     \"https://mega.nz/#!vEcTCISJ!3Rw0eHTZWPpdHBTbQEqBDikDEdFPr7fI8WxaXK9yZ9U\",\n",
    "#     \"https://mega.nz/#!zNcnGIoJ!oPJX9AvVVs11jc0SaK6vxP_lFUNTkEcK2WbxJpvjU5Y\",\n",
    ")\n",
    "file_names = (\n",
    "    'ted2020.tgz', # train & dev\n",
    "    'test.tgz', # test\n",
    ")\n",
    "prefix = Path(data_dir).absolute() / dataset_name\n",
    "\n",
    "prefix.mkdir(parents=True, exist_ok=True)\n",
    "for u, f in zip(urls, file_names):\n",
    "    path = prefix/f\n",
    "    if not path.exists():\n",
    "        if 'mega' in u:\n",
    "            !megadl {u} --path {path}\n",
    "        else:\n",
    "            !wget {u} -O {path}\n",
    "    if path.suffix == \".tgz\":\n",
    "        !tar -xvf {path} -C {prefix}\n",
    "    elif path.suffix == \".zip\":\n",
    "        !unzip -o {path} -d {prefix}\n",
    "!mv {prefix/'raw.en'} {prefix/'train_dev.raw.en'}\n",
    "!mv {prefix/'raw.zh'} {prefix/'train_dev.raw.zh'}\n",
    "!mv {prefix/'test.en'} {prefix/'test.raw.en'}\n",
    "!mv {prefix/'test.zh'} {prefix/'test.raw.zh'}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "src_lang = 'en'\n",
    "tgt_lang = 'zh'\n",
    "\n",
    "data_prefix = f'{prefix}/train_dev.raw'\n",
    "test_prefix = f'{prefix}/test.raw'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!head {data_prefix+'.'+src_lang} -n 5\n",
    "!head {data_prefix+'.'+tgt_lang} -n 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def strQ2B(ustring):\n",
    "    \"\"\"Full width -> half width\"\"\"\n",
    "    # reference:https://ithelp.ithome.com.tw/articles/10233122\n",
    "    ss = []\n",
    "    for s in ustring:\n",
    "        rstring = \"\"\n",
    "        for uchar in s:\n",
    "            inside_code = ord(uchar)\n",
    "            if inside_code == 12288:  # Full width space: direct conversion\n",
    "                inside_code = 32\n",
    "            elif (inside_code >= 65281 and inside_code <= 65374):  # Full width chars (except space) conversion\n",
    "                inside_code -= 65248\n",
    "            rstring += chr(inside_code)\n",
    "        ss.append(rstring)\n",
    "    return ''.join(ss)\n",
    "                \n",
    "def clean_s(s, lang):\n",
    "    if lang == 'en':\n",
    "        s = re.sub(r\"\\([^()]*\\)\", \"\", s) # remove ([text])\n",
    "        s = s.replace('-', '') # remove '-'\n",
    "        s = re.sub('([.,;!?()\\\"])', r' \\1 ', s) # keep punctuation\n",
    "    elif lang == 'zh':\n",
    "        s = strQ2B(s) # Q2B\n",
    "        s = re.sub(r\"\\([^()]*\\)\", \"\", s) # remove ([text])\n",
    "        s = s.replace(' ', '')\n",
    "        s = s.replace('—', '')\n",
    "        s = s.replace('“', '\"')\n",
    "        s = s.replace('”', '\"')\n",
    "        s = s.replace('_', '')\n",
    "        s = re.sub('([。,;!?()\\\"~「」])', r' \\1 ', s) # keep punctuation\n",
    "    s = ' '.join(s.strip().split())\n",
    "    return s\n",
    "\n",
    "def len_s(s, lang):\n",
    "    if lang == 'zh':\n",
    "        return len(s)\n",
    "    return len(s.split())\n",
    "\n",
    "def clean_corpus(prefix, l1, l2, ratio=9, max_len=1000, min_len=1):\n",
    "    if Path(f'{prefix}.clean.{l1}').exists() and Path(f'{prefix}.clean.{l2}').exists():\n",
    "        print(f'{prefix}.clean.{l1} & {l2} exists. skipping clean.')\n",
    "        return\n",
    "    with open(f'{prefix}.{l1}', 'r') as l1_in_f:\n",
    "        with open(f'{prefix}.{l2}', 'r') as l2_in_f:\n",
    "            with open(f'{prefix}.clean.{l1}', 'w') as l1_out_f:\n",
    "                with open(f'{prefix}.clean.{l2}', 'w') as l2_out_f:\n",
    "                    for s1 in l1_in_f:\n",
    "                        s1 = s1.strip()\n",
    "                        s2 = l2_in_f.readline().strip()\n",
    "                        s1 = clean_s(s1, l1)\n",
    "                        s2 = clean_s(s2, l2)\n",
    "                        s1_len = len_s(s1, l1)\n",
    "                        s2_len = len_s(s2, l2)\n",
    "                        if min_len > 0: # remove short sentence\n",
    "                            if s1_len < min_len or s2_len < min_len:\n",
    "                                continue\n",
    "                        if max_len > 0: # remove long sentence\n",
    "                            if s1_len > max_len or s2_len > max_len:\n",
    "                                continue\n",
    "                        if ratio > 0: # remove by ratio of length\n",
    "                            if s1_len/s2_len > ratio or s2_len/s1_len > ratio:\n",
    "                                continue\n",
    "                        print(s1, file=l1_out_f)\n",
    "                        print(s2, file=l2_out_f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_corpus(data_prefix, src_lang, tgt_lang)\n",
    "clean_corpus(test_prefix, src_lang, tgt_lang, ratio=-1, min_len=-1, max_len=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!head {data_prefix+'.clean.'+src_lang} -n 5\n",
    "!head {data_prefix+'.clean.'+tgt_lang} -n 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Split into train/valid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_ratio = 0.01 # 3000~4000 would suffice\n",
    "train_ratio = 1 - valid_ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if (prefix/f'train.clean.{src_lang}').exists() \\\n",
    "and (prefix/f'train.clean.{tgt_lang}').exists() \\\n",
    "and (prefix/f'valid.clean.{src_lang}').exists() \\\n",
    "and (prefix/f'valid.clean.{tgt_lang}').exists():\n",
    "    print(f'train/valid splits exists. skipping split.')\n",
    "else:\n",
    "    line_num = sum(1 for line in open(f'{data_prefix}.clean.{src_lang}'))\n",
    "    labels = list(range(line_num))\n",
    "    random.shuffle(labels)\n",
    "    for lang in [src_lang, tgt_lang]:\n",
    "        train_f = open(os.path.join(data_dir, dataset_name, f'train.clean.{lang}'), 'w')\n",
    "        valid_f = open(os.path.join(data_dir, dataset_name, f'valid.clean.{lang}'), 'w')\n",
    "        count = 0\n",
    "        for line in open(f'{data_prefix}.clean.{lang}', 'r'):\n",
    "            if labels[count]/line_num < train_ratio:\n",
    "                train_f.write(line)\n",
    "            else:\n",
    "                valid_f.write(line)\n",
    "            count += 1\n",
    "        train_f.close()\n",
    "        valid_f.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Subword Units \n",
    "Out of vocabulary (OOV) has been a major problem in machine translation. This can be alleviated by using subword units.\n",
    "- We will use the [sentencepiece](#kudo-richardson-2018-sentencepiece) package\n",
    "- select 'unigram' or 'byte-pair encoding (BPE)' algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "vocab_size = 8000\n",
    "if (prefix/f'spm{vocab_size}.model').exists():\n",
    "    print(f'{prefix}/spm{vocab_size}.model exists. skipping spm_train.')\n",
    "else:\n",
    "    spm.SentencePieceTrainer.train(\n",
    "        input=','.join([f'{prefix}/train.clean.{src_lang}',\n",
    "                        f'{prefix}/valid.clean.{src_lang}',\n",
    "                        f'{prefix}/train.clean.{tgt_lang}',\n",
    "                        f'{prefix}/valid.clean.{tgt_lang}']),\n",
    "        model_prefix=prefix/f'spm{vocab_size}',\n",
    "        vocab_size=vocab_size,\n",
    "        character_coverage=1,\n",
    "        model_type='unigram', # 'bpe' works as well\n",
    "        input_sentence_size=1e6,\n",
    "        shuffle_input_sentence=True,\n",
    "        normalization_rule_name='nmt_nfkc_cf',\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spm_model = spm.SentencePieceProcessor(model_file=str(prefix/f'spm{vocab_size}.model'))\n",
    "in_tag = {\n",
    "    'train': 'train.clean',\n",
    "    'valid': 'valid.clean',\n",
    "    'test': 'test.raw.clean',\n",
    "}\n",
    "for split in ['train', 'valid', 'test']:\n",
    "    for lang in [src_lang, tgt_lang]:\n",
    "        out_path = prefix/f'{split}.{lang}'\n",
    "        if out_path.exists():\n",
    "            print(f\"{out_path} exists. skipping spm_encode.\")\n",
    "        else:\n",
    "            with open(prefix/f'{split}.{lang}', 'w') as out_f:\n",
    "                with open(prefix/f'{in_tag[split]}.{lang}', 'r') as in_f:\n",
    "                    for line in in_f:\n",
    "                        line = line.strip()\n",
    "                        tok = spm_model.encode(line, out_type=str)\n",
    "                        print(' '.join(tok), file=out_f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!head {data_dir+'/'+dataset_name+'/train.'+src_lang} -n 5\n",
    "!head {data_dir+'/'+dataset_name+'/train.'+tgt_lang} -n 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binarize the data with fairseq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "binpath = Path('./DATA/data-bin', dataset_name)\n",
    "if binpath.exists():\n",
    "    print(binpath, \"exists, will not overwrite!\")\n",
    "else:\n",
    "    !python -m fairseq_cli.preprocess \\\n",
    "        --source-lang {src_lang}\\\n",
    "        --target-lang {tgt_lang}\\\n",
    "        --trainpref {prefix/'train'}\\\n",
    "        --validpref {prefix/'valid'}\\\n",
    "        --testpref {prefix/'test'}\\\n",
    "        --destdir {binpath}\\\n",
    "        --joined-dictionary\\\n",
    "        --workers 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configuration for Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = Namespace(\n",
    "    datadir = \"./DATA/data-bin/ted2020\",\n",
    "    savedir = \"./checkpoints/rnn\",\n",
    "    source_lang = \"en\",\n",
    "    target_lang = \"zh\",\n",
    "    \n",
    "    # cpu threads when fetching & processing data.\n",
    "    num_workers=2,  \n",
    "    # batch size in terms of tokens. gradient accumulation increases the effective batchsize.\n",
    "    max_tokens=8192,\n",
    "    accum_steps=2,\n",
    "    \n",
    "    # the lr s calculated from Noam lr scheduler. you can tune the maximum lr by this factor.\n",
    "    lr_factor=2.,\n",
    "    lr_warmup=4000,\n",
    "    \n",
    "    # clipping gradient norm helps alleviate gradient exploding\n",
    "    clip_norm=1.0,\n",
    "    \n",
    "    # maximum epochs for training\n",
    "    max_epoch=30,\n",
    "    start_epoch=1,\n",
    "    \n",
    "    # beam size for beam search\n",
    "    beam=5, \n",
    "    # generate sequences of maximum length ax + b, where x is the source length\n",
    "    max_len_a=1.2, \n",
    "    max_len_b=10, \n",
    "    # when decoding, post process sentence by removing sentencepiece symbols and jieba tokenization.\n",
    "    post_process = \"sentencepiece\",\n",
    "    \n",
    "    # checkpoints\n",
    "    keep_last_epochs=5,\n",
    "    resume=None, # if resume from checkpoint name (under config.savedir)\n",
    "    \n",
    "    # logging\n",
    "    use_wandb=False,\n",
    ")"
   ]
  },
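  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `lr_factor` and `lr_warmup` settings above refer to the Noam schedule. As a minimal sketch (not this notebook's own optimizer code), the schedule from the original Transformer paper can be written as below. The function name `noam_lr` and the model dimension `d_model=256` are assumptions made for illustration only.\n",
    "\n",
    "```python\n",
    "def noam_lr(step, d_model=256, factor=2.0, warmup=4000):\n",
    "    # d_model is a hypothetical embedding size; factor and warmup mirror\n",
    "    # lr_factor and lr_warmup from the config above.\n",
    "    step = max(step, 1)  # avoid 0 ** -0.5 at the very first call\n",
    "    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)\n",
    "\n",
    "for step in (1, 1000, 4000, 16000):\n",
    "    print(step, noam_lr(step))\n",
    "```\n",
    "\n",
    "The learning rate rises linearly for `warmup` steps, peaks there, and then decays with the inverse square root of the step count, so `lr_factor` scales the peak value."
   ]
  },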
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Logging\n",
    "- logging package logs ordinary messages\n",
    "- wandb logs the loss, bleu, etc. in the training process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "logging.basicConfig(\n",
    "    format=\"%(asctime)s | %(levelname)s | %(name)s | %(message)s\",\n",
    "    datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
    "    level=\"INFO\", # \"DEBUG\" \"WARNING\" \"ERROR\"\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "proj = \"hw5.seq2seq\"\n",
    "logger = logging.getLogger(proj)\n",
    "if config.use_wandb:\n",
    "    import wandb\n",
    "    wandb.init(project=proj, name=Path(config.savedir).stem, config=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CUDA Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cuda_env = utils.CudaEnvironment()\n",
    "utils.CudaEnvironment.pretty_print_cuda_env_list([cuda_env])\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataloading"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## We borrow the TranslationTask from fairseq\n",
    "* used to load the binarized data created above\n",
    "* well-implemented data iterator (dataloader)\n",
    "* built-in task.source_dictionary and task.target_dictionary are also handy\n",
    "* well-implemented beach search decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairseq.tasks.translation import TranslationConfig, TranslationTask\n",
    "\n",
    "## setup task\n",
    "task_cfg = TranslationConfig(\n",
    "    data=config.datadir,\n",
    "    source_lang=config.source_lang,\n",
    "    target_lang=config.target_lang,\n",
    "    train_subset=\"train\",\n",
    "    required_seq_len_multiple=8,\n",
    "    dataset_impl=\"mmap\",\n",
    "    upsample_primary=1,\n",
    ")\n",
    "task = TranslationTask.setup_task(task_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(\"loading data for epoch 1\")\n",
    "task.load_dataset(split=\"train\", epoch=1, combine=True) # combine if you have back-translation data.\n",
    "task.load_dataset(split=\"valid\", epoch=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = task.dataset(\"valid\")[1]\n",
    "pprint.pprint(sample)\n",
    "pprint.pprint(\n",
    "    \"Source: \" + \\\n",
    "    task.source_dictionary.string(\n",
    "        sample['source'],\n",
    "        config.post_process,\n",
    "    )\n",
    ")\n",
    "pprint.pprint(\n",
    "    \"Target: \" + \\\n",
    "    task.target_dictionary.string(\n",
    "        sample['target'],\n",
    "        config.post_process,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset Iterator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Controls every batch to contain no more than N tokens, which optimizes GPU memory efficiency\n",
    "* Shuffles the training set for every epoch\n",
    "* Ignore sentences exceeding maximum length\n",
    "* Pad all sentences in a batch to the same length, which enables parallel computing by GPU\n",
    "* Add eos and shift one token\n",
    "    - teacher forcing: to train the model to predict the next token based on prefix, we feed the right shifted target sequence as the decoder input.\n",
    "    - generally, prepending bos to the target would do the job (as shown below)\n",
    "![seq2seq](https://i.imgur.com/0zeDyuI.png)\n",
    "    - in fairseq however, this is done by moving the eos token to the begining. Empirically, this has the same effect. For instance:\n",
    "    ```\n",
    "    # output target (target) and Decoder input (prev_output_tokens): \n",
    "                   eos = 2\n",
    "                target = 419,  711,  238,  888,  792,   60,  968,    8,    2\n",
    "    prev_output_tokens = 2,  419,  711,  238,  888,  792,   60,  968,    8\n",
    "    ```\n"
   ]
  },
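  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The eos shift described above can be sketched in plain Python. This is only an illustration on lists: fairseq applies the same idea to padded tensors when it collates a batch, and the helper name `shift_right_with_eos` is hypothetical.\n",
    "\n",
    "```python\n",
    "def shift_right_with_eos(target, eos=2):\n",
    "    # Build the decoder input by moving the trailing eos to the front.\n",
    "    # eos=2 follows the example above; the target must end with eos.\n",
    "    assert target[-1] == eos, 'target is expected to end with eos'\n",
    "    return [eos] + target[:-1]\n",
    "\n",
    "target = [419, 711, 238, 888, 792, 60, 968, 8, 2]\n",
    "prev_output_tokens = shift_right_with_eos(target)\n",
    "\n",
    "# Teacher forcing: at position t the decoder has seen prev_output_tokens[:t + 1]\n",
    "# and is trained to predict target[t].\n",
    "for t in range(3):\n",
    "    print(prev_output_tokens[:t + 1], '->', target[t])\n",
    "```\n",
    "\n",
    "Both sequences have the same length, so the loss can be computed position by position without any extra alignment."
   ]
  },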
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data_iterator(task, split, epoch=1, max_tokens=4000, num_workers=1, cached=True):\n",
    "    batch_iterator = task.get_batch_iterator(\n",
    "        dataset=task.dataset(split),\n",
    "        max_tokens=max_tokens,\n",
    "        max_sentences=None,\n",
    "        max_positions=utils.resolve_max_positions(\n",
    "            task.max_positions(),\n",
    "            max_tokens,\n",
    "        ),\n",
    "        ignore_invalid_inputs=True,\n",
    "        seed=seed,\n",
    "        num_workers=num_workers,\n",
    "        epoch=epoch,\n",
    "        disable_iterator_cache=not cached,\n",
    "        # Set this to False to speed up. However, if set to False, changing max_tokens beyond \n",
    "        # first call of this method has no effect. \n",
    "    )\n",
    "    return batch_iterator\n",
    "\n",
    "demo_epoch_obj = load_data_iterator(task, \"valid\", epoch=1, max_tokens=20, num_workers=1, cached=False)\n",
    "demo_iter = demo_epoch_obj.next_epoch_itr(shuffle=True)\n",
    "sample = next(demo_iter)\n",
    "sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* each batch is a python dict, with string key and Tensor value. Contents are described below:\n",
    "```python\n",
    "batch = {\n",
    "    \"id\": id, # id for each example \n",
    "    \"nsentences\": len(samples), # batch size (sentences)\n",
    "    \"ntokens\": ntokens, # batch size (tokens)\n",
    "    \"net_input\": {\n",
    "        \"src_tokens\": src_tokens, # sequence in source language\n",
    "        \"src_lengths\": src_lengths, # sequence length of each example before padding\n",
    "        \"prev_output_tokens\": prev_output_tokens, # right shifted target, as mentioned above.\n",
    "    },\n",
    "    \"target\": target, # target sequence\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Architecture\n",
    "* We again inherit fairseq's encoder, decoder and model, so that in the testing phase we can directly leverage fairseq's beam search decoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairseq.models import (\n",
    "    FairseqEncoder, \n",
    "    FairseqIncrementalDecoder,\n",
    "    FairseqEncoderDecoderModel\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- The Encoder is a RNN or Transformer Encoder. The following description is for RNN. For every input token, Encoder will generate a output vector and a hidden states vector, and the hidden states vector is passed on to the next step. In other words, the Encoder sequentially reads in the input sequence, and outputs a single vector at each timestep, then finally outputs the final hidden states, or content vector, at the last timestep.\n",
    "- Parameters:\n",
    "  - *args*\n",
    "      - encoder_embed_dim: the dimension of embeddings, this compresses the one-hot vector into fixed dimensions, which achieves dimension reduction\n",
    "      - encoder_ffn_embed_dim is the dimension of hidden states and output vectors\n",
    "      - encoder_layers is the number of layers for Encoder RNN\n",
    "      - dropout determines the probability of a neuron's activation being set to 0, in order to prevent overfitting. Generally this is applied in training, and removed in testing.\n",
    "  - *dictionary*: the dictionary provided by fairseq. it's used to obtain the padding index, and in turn the encoder padding mask. \n",
    "  - *embed_tokens*: an instance of token embeddings (nn.Embedding)\n",
    "\n",
    "- Inputs: \n",
    "    - *src_tokens*: integer sequence representing english e.g. 1, 28, 29, 205, 2 \n",
    "- Outputs: \n",
    "    - *outputs*: the output of RNN at each timestep, can be furthur processed by Attention\n",
    "    - *final_hiddens*: the hidden states of each timestep, will be passed to decoder for decoding\n",
    "    - *encoder_padding_mask*: this tells the decoder which position to ignore\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNEncoder(FairseqEncoder):\n",
    "    def __init__(self, args, dictionary, embed_tokens):\n",
    "        super().__init__(dictionary)\n",
    "        self.embed_tokens = embed_tokens\n",
    "        \n",
    "        self.embed_dim = args.encoder_embed_dim\n",
    "        self.hidden_dim = args.encoder_ffn_embed_dim\n",
    "        self.num_layers = args.encoder_layers\n",
    "        \n",
    "        self.dropout_in_module = nn.Dropout(args.dropout)\n",
    "        self.rnn = nn.GRU(\n",
    "            self.embed_dim, \n",
    "            self.hidden_dim, \n",
    "            self.num_layers, \n",
    "            dropout=args.dropout, \n",
    "            batch_first=False, \n",
    "            bidirectional=True\n",
    "        )\n",
    "        self.dropout_out_module = nn.Dropout(args.dropout)\n",
    "        \n",
    "        self.padding_idx = dictionary.pad()\n",
    "        \n",
    "    def combine_bidir(self, outs, bsz: int):\n",
    "        out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()\n",
    "        return out.view(self.num_layers, bsz, -1)\n",
    "\n",
    "    def forward(self, src_tokens, **unused):\n",
    "        bsz, seqlen = src_tokens.size()\n",
    "        \n",
    "        # get embeddings\n",
    "        x = self.embed_tokens(src_tokens)\n",
    "        x = self.dropout_in_module(x)\n",
    "\n",
    "        # B x T x C -> T x B x C\n",
    "        x = x.transpose(0, 1)\n",
    "        \n",
    "        # pass thru bidirectional RNN\n",
    "        h0 = x.new_zeros(2 * self.num_layers, bsz, self.hidden_dim)\n",
    "        x, final_hiddens = self.rnn(x, h0)\n",
    "        outputs = self.dropout_out_module(x)\n",
    "        # outputs = [sequence len, batch size, hid dim * directions]\n",
    "        # hidden =  [num_layers * directions, batch size  , hid dim]\n",
    "        \n",
    "        # Since Encoder is bidirectional, we need to concatenate the hidden states of two directions\n",
    "        final_hiddens = self.combine_bidir(final_hiddens, bsz)\n",
    "        # hidden =  [num_layers x batch x num_directions*hidden]\n",
    "        \n",
    "        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()\n",
    "        return tuple(\n",
    "            (\n",
    "                outputs,  # seq_len x batch x hidden\n",
    "                final_hiddens,  # num_layers x batch x num_directions*hidden\n",
    "                encoder_padding_mask,  # seq_len x batch\n",
    "            )\n",
    "        )\n",
    "    \n",
    "    def reorder_encoder_out(self, encoder_out, new_order):\n",
    "        # This is used by fairseq's beam search. How and why is not particularly important here.\n",
    "        return tuple(\n",
    "            (\n",
    "                encoder_out[0].index_select(1, new_order),\n",
    "                encoder_out[1].index_select(1, new_order),\n",
    "                encoder_out[2].index_select(1, new_order),\n",
    "            )\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- When the input sequence is long, \"content vector\" alone cannot accurately represent the whole sequence, attention mechanism can provide the Decoder more information.\n",
    "- According to the **Decoder embeddings** of the current timestep, match the **Encoder outputs** with decoder embeddings to determine correlation, and then sum the Encoder outputs weighted by the correlation as the input to **Decoder** RNN.\n",
    "- Common attention implementations use neural network / dot product as the correlation between **query** (decoder embeddings) and **key** (Encoder outputs), followed by **softmax**  to obtain a distribution, and finally **values** (Encoder outputs) is **weighted sum**-ed by said distribution.\n",
    "\n",
    "- Parameters:\n",
    "  - *input_embed_dim*: dimensionality of key, should be that of the vector in decoder to attend others\n",
    "  - *source_embed_dim*: dimensionality of query, should be that of the vector to be attended to (encoder outputs)\n",
    "  - *output_embed_dim*: dimensionality of value, should be that of the vector after attention, expected by the next layer\n",
    "\n",
    "- Inputs: \n",
    "    - *inputs*: is the key, the vector to attend to others\n",
    "    - *encoder_outputs*:  is the query/value, the vector to be attended to\n",
    "    - *encoder_padding_mask*: this tells the decoder which position to ignore\n",
    "- Outputs: \n",
    "    - *output*: the context vector after attention\n",
    "    - *attention score*: the attention distribution\n"
   ]
  },
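  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three steps described above (score, softmax, weighted sum) can be sketched on plain Python lists. This is a minimal illustration, not the layer used in this notebook: `dot_product_attention` and the toy vectors are made up for the example, it handles a single query without a batch dimension, and it leaves out the input and output projections.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def dot_product_attention(query, keys, values, padding_mask):\n",
    "    # padding_mask[i] is True where source position i is padding and must be ignored\n",
    "    # 1) correlation between the query and every key (dot product)\n",
    "    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]\n",
    "    # padded positions get -inf so that softmax gives them zero weight\n",
    "    scores = [-math.inf if pad else s for s, pad in zip(scores, padding_mask)]\n",
    "    # 2) softmax over the source positions (shifted by the max for stability)\n",
    "    top = max(scores)\n",
    "    exps = [math.exp(s - top) for s in scores]\n",
    "    total = sum(exps)\n",
    "    weights = [e / total for e in exps]\n",
    "    # 3) weighted sum of the values\n",
    "    dim = len(values[0])\n",
    "    context = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]\n",
    "    return context, weights\n",
    "\n",
    "# toy example: 3 source positions, the last one is padding\n",
    "query = [1.0, 0.0]\n",
    "encoder_outputs = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]\n",
    "context, weights = dot_product_attention(\n",
    "    query, encoder_outputs, encoder_outputs, padding_mask=[False, False, True]\n",
    ")\n",
    "print(weights)  # the padded position receives weight 0.0\n",
    "print(context)\n",
    "```\n",
    "\n",
    "The masked position is set to negative infinity before softmax, so its weight is exactly zero and it contributes nothing to the context vector, even though its raw dot product would have been the largest."
   ]
  },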
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AttentionLayer(nn.Module):\n",
    "    def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False):\n",
    "        super().__init__()\n",
    "\n",
    "        self.input_proj = nn.Linear(input_embed_dim, source_embed_dim, bias=bias)\n",
    "        self.output_proj = nn.Linear(\n",
    "            input_embed_dim + source_embed_dim, output_embed_dim, bias=bias\n",
    "        )\n",
    "\n",
    "    def forward(self, inputs, encoder_outputs, encoder_padding_mask):\n",
    "        # inputs: T, B, dim\n",
    "        # encoder_outputs: S x B x dim\n",
    "        # padding mask:  S x B\n",
    "        \n",
    "        # convert all to batch first\n",
    "        inputs = inputs.transpose(1,0) # B, T, dim\n",
    "        encoder_outputs = encoder_outputs.transpose(1,0) # B, S, dim\n",
    "        encoder_padding_mask = encoder_padding_mask.transpose(1,0) # B, S\n",
    "        \n",
    "        # project to the dimensionality of encoder_outputs\n",
    "        x = self.input_proj(inputs)\n",
    "\n",
    "        # compute attention\n",
    "        # (B, T, dim) x (B, dim, S) = (B, T, S)\n",
    "        attn_scores = torch.bmm(x, encoder_outputs.transpose(1,2))\n",
    "\n",
    "        # cancel the attention at positions corresponding to padding\n",
    "        if encoder_padding_mask is not None:\n",
    "            # leveraging broadcast  B, S -> (B, 1, S)\n",
    "            encoder_padding_mask = encoder_padding_mask.unsqueeze(1)\n",
    "            attn_scores = (\n",
    "                attn_scores.float()\n",
    "                .masked_fill_(encoder_padding_mask, float(\"-inf\"))\n",
    "                .type_as(attn_scores)\n",
    "            )  # FP16 support: cast to float and back\n",
    "\n",
    "        # softmax on the dimension corresponding to source sequence\n",
    "        attn_scores = F.softmax(attn_scores, dim=-1)\n",
    "\n",
    "        # shape (B, T, S) x (B, S, dim) = (B, T, dim) weighted sum\n",
    "        x = torch.bmm(attn_scores, encoder_outputs)\n",
    "\n",
    "        # (B, T, dim)\n",
    "        x = torch.cat((x, inputs), dim=-1)\n",
    "        x = torch.tanh(self.output_proj(x)) # concat + linear + tanh\n",
    "        \n",
    "        # restore shape (B, T, dim) -> (T, B, dim)\n",
    "        return x.transpose(1,0), attn_scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* The hidden states of **Decoder** will be initialized by the final hidden states of **Encoder** (the content vector)\n",
    "* At the same time, **Decoder** will change its hidden states based on the input of the current timestep (the outputs of previous timesteps), and generates an output\n",
    "* Attention improves the performance\n",
    "* The seq2seq steps are implemented in decoder, so that later the Seq2Seq class can accept RNN and Transformer, without furthur modification.\n",
    "- Parameters:\n",
    "  - *args*\n",
    "      - decoder_embed_dim: is the dimensionality of the decoder embeddings, similar to encoder_embed_dim，\n",
    "      - decoder_ffn_embed_dim: is the dimensionality of the decoder RNN hidden states, similar to encoder_ffn_embed_dim\n",
    "      - decoder_layers: number of layers of RNN decoder\n",
    "      - share_decoder_input_output_embed: usually, the projection matrix of the decoder will share weights with the decoder input embeddings\n",
    "  - *dictionary*: the dictionary provided by fairseq\n",
    "  - *embed_tokens*: an instance of token embeddings (nn.Embedding)\n",
    "- Inputs: \n",
    "    - *prev_output_tokens*: integer sequence representing the right-shifted target e.g. 1, 28, 29, 205, 2 \n",
    "    - *encoder_out*: encoder's output.\n",
    "    - *incremental_state*: in order to speed up decoding during test time, we will save the hidden state of each timestep. see forward() for details.\n",
    "- Outputs: \n",
    "    - *outputs*: the logits (before softmax) output of decoder for each timesteps\n",
    "    - *extra*: unsused"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNDecoder(FairseqIncrementalDecoder):\n",
    "    def __init__(self, args, dictionary, embed_tokens):\n",
    "        super().__init__(dictionary)\n",
    "        self.embed_tokens = embed_tokens\n",
    "        \n",
    "        assert args.decoder_layers == args.encoder_layers, f\"\"\"seq2seq rnn requires that encoder \n",
    "        and decoder have same layers of rnn. got: {args.encoder_layers, args.decoder_layers}\"\"\"\n",
    "        assert args.decoder_ffn_embed_dim == args.encoder_ffn_embed_dim*2, f\"\"\"seq2seq-rnn requires \n",
    "        that decoder hidden to be 2*encoder hidden dim. got: {args.decoder_ffn_embed_dim, args.encoder_ffn_embed_dim*2}\"\"\"\n",
    "        \n",
    "        self.embed_dim = args.decoder_embed_dim\n",
    "        self.hidden_dim = args.decoder_ffn_embed_dim\n",
    "        self.num_layers = args.decoder_layers\n",
    "        \n",
    "        \n",
    "        self.dropout_in_module = nn.Dropout(args.dropout)\n",
    "        self.rnn = nn.GRU(\n",
    "            self.embed_dim, \n",
    "            self.hidden_dim, \n",
    "            self.num_layers, \n",
    "            dropout=args.dropout, \n",
    "            batch_first=False, \n",
    "            bidirectional=False\n",
    "        )\n",
    "        self.attention = AttentionLayer(\n",
    "            self.embed_dim, self.hidden_dim, self.embed_dim, bias=False\n",
    "        ) \n",
    "        # self.attention = None\n",
    "        self.dropout_out_module = nn.Dropout(args.dropout)\n",
    "        \n",
    "        if self.hidden_dim != self.embed_dim:\n",
    "            self.project_out_dim = nn.Linear(self.hidden_dim, self.embed_dim)\n",
    "        else:\n",
    "            self.project_out_dim = None\n",
    "        \n",
    "        if args.share_decoder_input_output_embed:\n",
    "            self.output_projection = nn.Linear(\n",
    "                self.embed_tokens.weight.shape[1],\n",
    "                self.embed_tokens.weight.shape[0],\n",
    "                bias=False,\n",
    "            )\n",
    "            self.output_projection.weight = self.embed_tokens.weight\n",
    "        else:\n",
    "            self.output_projection = nn.Linear(\n",
    "                self.output_embed_dim, len(dictionary), bias=False\n",
    "            )\n",
    "            nn.init.normal_(\n",
    "                self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5\n",
    "            )\n",
    "        \n",
    "    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, **unused):\n",
    "        # extract the outputs from encoder\n",
    "        encoder_outputs, encoder_hiddens, encoder_padding_mask = encoder_out\n",
    "        # outputs:          seq_len x batch x num_directions*hidden\n",
    "        # encoder_hiddens:  num_layers x batch x num_directions*encoder_hidden\n",
    "        # padding_mask:     seq_len x batch\n",
    "        \n",
    "        if incremental_state is not None and len(incremental_state) > 0:\n",
    "            # if the information from last timestep is retained, we can continue from there instead of starting from bos\n",
    "            prev_output_tokens = prev_output_tokens[:, -1:]\n",
    "            cache_state = self.get_incremental_state(incremental_state, \"cached_state\")\n",
    "            prev_hiddens = cache_state[\"prev_hiddens\"]\n",
    "        else:\n",
    "            # incremental state does not exist, either this is training time, or the first timestep of test time\n",
    "            # prepare for seq2seq: pass the encoder_hidden to the decoder hidden states\n",
    "            prev_hiddens = encoder_hiddens\n",
    "        \n",
    "        bsz, seqlen = prev_output_tokens.size()\n",
    "        \n",
    "        # embed tokens\n",
    "        x = self.embed_tokens(prev_output_tokens)\n",
    "        x = self.dropout_in_module(x)\n",
    "\n",
    "        # B x T x C -> T x B x C\n",
    "        x = x.transpose(0, 1)\n",
    "                \n",
    "        # decoder-to-encoder attention\n",
    "        if self.attention is not None:\n",
    "            x, attn = self.attention(x, encoder_outputs, encoder_padding_mask)\n",
    "                        \n",
    "        # pass thru unidirectional RNN\n",
    "        x, final_hiddens = self.rnn(x, prev_hiddens)\n",
    "        # outputs = [sequence len, batch size, hid dim]\n",
    "        # hidden =  [num_layers * directions, batch size  , hid dim]\n",
    "        x = self.dropout_out_module(x)\n",
    "                \n",
    "        # project to embedding size (if hidden differs from embed size, and share_embedding is True, \n",
    "        # we need to do an extra projection)\n",
    "        if self.project_out_dim != None:\n",
    "            x = self.project_out_dim(x)\n",
    "        \n",
    "        # project to vocab size\n",
    "        x = self.output_projection(x)\n",
    "        \n",
    "        # T x B x C -> B x T x C\n",
    "        x = x.transpose(1, 0)\n",
    "        \n",
    "        # if incremental, record the hidden states of current timestep, which will be restored in the next timestep\n",
    "        cache_state = {\n",
    "            \"prev_hiddens\": final_hiddens,\n",
    "        }\n",
    "        self.set_incremental_state(incremental_state, \"cached_state\", cache_state)\n",
    "        \n",
    "        return x, None\n",
    "    \n",
    "    def reorder_incremental_state(\n",
    "        self,\n",
    "        incremental_state,\n",
    "        new_order,\n",
    "    ):\n",
    "        # This is used by fairseq's beam search. How and why is not particularly important here.\n",
    "        cache_state = self.get_incremental_state(incremental_state, \"cached_state\")\n",
    "        prev_hiddens = cache_state[\"prev_hiddens\"]\n",
    "        prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens]\n",
    "        cache_state = {\n",
    "            \"prev_hiddens\": torch.stack(prev_hiddens),\n",
    "        }\n",
    "        self.set_incremental_state(incremental_state, \"cached_state\", cache_state)\n",
    "        return"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Seq2Seq\n",
    "- Composed of **Encoder** and **Decoder**\n",
    "- Recieves inputs and pass to **Encoder** \n",
    "- Pass the outputs from **Encoder** to **Decoder**\n",
    "- **Decoder** will decode according to outputs of previous timesteps as well as **Encoder** outputs  \n",
    "- Once done decoding, return the **Decoder** outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2Seq(FairseqEncoderDecoderModel):\n",
    "    def __init__(self, args, encoder, decoder):\n",
    "        super().__init__(encoder, decoder)\n",
    "        self.args = args\n",
    "    \n",
    "    def forward(\n",
    "        self,\n",
    "        src_tokens,\n",
    "        src_lengths,\n",
    "        prev_output_tokens,\n",
    "        return_all_hiddens: bool = True,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Run the forward pass for an encoder-decoder model.\n",
    "        \"\"\"\n",
    "        encoder_out = self.encoder(\n",
    "            src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens\n",
    "        )\n",
    "        logits, extra = self.decoder(\n",
    "            prev_output_tokens,\n",
    "            encoder_out=encoder_out,\n",
    "            src_lengths=src_lengths,\n",
    "            return_all_hiddens=return_all_hiddens,\n",
    "        )\n",
    "        return logits, extra"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # HINT: transformer architecture\n",
    "# from fairseq.models.transformer import (\n",
    "#     TransformerEncoder, \n",
    "#     TransformerDecoder,\n",
    "# )\n",
    "\n",
    "def build_model(args, task):\n",
    "    \"\"\" build a model instance based on hyperparameters \"\"\"\n",
    "    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary\n",
    "\n",
    "    # token embeddings\n",
    "    encoder_embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, src_dict.pad())\n",
    "    decoder_embed_tokens = nn.Embedding(len(tgt_dict), args.decoder_embed_dim, tgt_dict.pad())\n",
    "    \n",
    "    # encoder decoder\n",
    "    # HINT: TODO: switch to TransformerEncoder & TransformerDecoder\n",
    "    encoder = RNNEncoder(args, src_dict, encoder_embed_tokens)\n",
    "    decoder = RNNDecoder(args, tgt_dict, decoder_embed_tokens)\n",
    "    \n",
    "    # sequence to sequence model\n",
    "    model = Seq2Seq(args, encoder, decoder)\n",
    "    \n",
    "    # initialization for seq2seq model is important, requires extra handling\n",
    "    def init_params(module):\n",
    "        from fairseq.modules import MultiheadAttention\n",
    "        if isinstance(module, nn.Linear):\n",
    "            module.weight.data.normal_(mean=0.0, std=0.02)\n",
    "            if module.bias is not None:\n",
    "                module.bias.data.zero_()\n",
    "        if isinstance(module, nn.Embedding):\n",
    "            module.weight.data.normal_(mean=0.0, std=0.02)\n",
    "            if module.padding_idx is not None:\n",
    "                module.weight.data[module.padding_idx].zero_()\n",
    "        if isinstance(module, MultiheadAttention):\n",
    "            module.q_proj.weight.data.normal_(mean=0.0, std=0.02)\n",
    "            module.k_proj.weight.data.normal_(mean=0.0, std=0.02)\n",
    "            module.v_proj.weight.data.normal_(mean=0.0, std=0.02)\n",
    "        if isinstance(module, nn.RNNBase):\n",
    "            for name, param in module.named_parameters():\n",
    "                if \"weight\" in name or \"bias\" in name:\n",
    "                    param.data.uniform_(-0.1, 0.1)\n",
    "            \n",
    "    # weight initialization\n",
    "    model.apply(init_params)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Architecture Related Configuration\n",
    "reference implementation\n",
    "\n",
    "|model|embedding dim|encoder ffn|encoder layers|decoder ffn|decoder layers|\n",
    "|-|-|-|-|-|-|\n",
    "|RNN|256|512|1|1024|1|\n",
    "|Transformer|256|1024|4|1024|4|\n",
    "\n",
    "For strong baseline, please refer to the hyperparameters for *transformer-base* in Table 3 in [Attention is all you need](#vaswani2017)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arch_args = Namespace(\n",
    "    encoder_embed_dim=256,\n",
    "    encoder_ffn_embed_dim=512,\n",
    "    encoder_layers=1,\n",
    "    decoder_embed_dim=256,\n",
    "    decoder_ffn_embed_dim=1024,\n",
    "    decoder_layers=1,\n",
    "    share_decoder_input_output_embed=True,\n",
    "    dropout=0.3,\n",
    ")\n",
    "\n",
    "# # HINT: these patches on parameters for Transformer\n",
    "# def add_transformer_args(args):\n",
    "#     args.encoder_attention_heads=4\n",
    "#     args.encoder_normalize_before=True\n",
    "    \n",
    "#     args.decoder_attention_heads=4\n",
    "#     args.decoder_normalize_before=True\n",
    "    \n",
    "#     args.activation_fn=\"relu\"\n",
    "#     args.max_source_positions=1024\n",
    "#     args.max_target_positions=1024\n",
    "    \n",
    "#     # patches on default parameters for Transformer (those not set above)\n",
    "#     from fairseq.models.transformer import base_architecture\n",
    "#     base_architecture(arch_args)\n",
    "\n",
    "# add_transformer_args(arch_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if config.use_wandb:\n",
    "    wandb.config.update(vars(arch_args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = build_model(arch_args, task)\n",
    "logger.info(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss: Label Smoothing Regularization\n",
    "* let the model learn to generate less concentrated distribution, and prevent over-confidence\n",
    "* sometimes the ground truth may not be the only answer. thus, when calculating loss, we reserve some probability for incorrect labels\n",
    "* avoids overfitting\n",
    "\n",
    "code [source](https://fairseq.readthedocs.io/en/latest/_modules/fairseq/criterions/label_smoothed_cross_entropy.html)"
   ]
  },
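  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how reserving probability for the other labels discourages over-confidence, the loss can be worked out by hand for a single prediction. The sketch below assumes the same form as the criterion in the next cell, `(1 - smoothing) * nll + (smoothing / vocab_size) * smooth_loss`; the helper `label_smoothed_nll` and the toy logits are hypothetical and only serve the illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def label_smoothed_nll(logits, target, smoothing):\n",
    "    # log-softmax over the vocabulary\n",
    "    top = max(logits)\n",
    "    log_z = top + math.log(sum(math.exp(x - top) for x in logits))\n",
    "    lprobs = [x - log_z for x in logits]\n",
    "    # usual negative log likelihood of the correct label\n",
    "    nll_loss = -lprobs[target]\n",
    "    # negative sum of the log probs of all labels\n",
    "    smooth_loss = -sum(lprobs)\n",
    "    # spread the smoothing mass uniformly over the vocabulary\n",
    "    eps_i = smoothing / len(lprobs)\n",
    "    return (1.0 - smoothing) * nll_loss + eps_i * smooth_loss\n",
    "\n",
    "confident = [10.0, 0.0, 0.0, 0.0]  # nearly all probability on label 0\n",
    "moderate = [2.0, 0.0, 0.0, 0.0]    # label 0 is preferred, but less sharply\n",
    "\n",
    "for name, logits in [('confident', confident), ('moderate', moderate)]:\n",
    "    plain = label_smoothed_nll(logits, target=0, smoothing=0.0)\n",
    "    smoothed = label_smoothed_nll(logits, target=0, smoothing=0.1)\n",
    "    print(name, round(plain, 4), round(smoothed, 4))\n",
    "```\n",
    "\n",
    "Without smoothing the very confident prediction has the lower loss. With `smoothing=0.1` the ordering flips: the near one-hot distribution pays a large penalty for the tiny log probabilities it assigns to the other labels."
   ]
  },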
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LabelSmoothedCrossEntropyCriterion(nn.Module):\n",
    "    def __init__(self, smoothing, ignore_index=None, reduce=True):\n",
    "        super().__init__()\n",
    "        self.smoothing = smoothing\n",
    "        self.ignore_index = ignore_index\n",
    "        self.reduce = reduce\n",
    "    \n",
    "    def forward(self, lprobs, target):\n",
    "        if target.dim() == lprobs.dim() - 1:\n",
    "            target = target.unsqueeze(-1)\n",
    "        # nll: Negative log likelihood，the cross-entropy when target is one-hot. following line is same as F.nll_loss\n",
    "        nll_loss = -lprobs.gather(dim=-1, index=target)\n",
    "        #  reserve some probability for other labels. thus when calculating cross-entropy, \n",
    "        # equivalent to summing the log probs of all labels\n",
    "        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)\n",
    "        if self.ignore_index is not None:\n",
    "            pad_mask = target.eq(self.ignore_index)\n",
    "            nll_loss.masked_fill_(pad_mask, 0.0)\n",
    "            smooth_loss.masked_fill_(pad_mask, 0.0)\n",
    "        else:\n",
    "            nll_loss = nll_loss.squeeze(-1)\n",
    "            smooth_loss = smooth_loss.squeeze(-1)\n",
    "        if self.reduce:\n",
    "            nll_loss = nll_loss.sum()\n",
    "            smooth_loss = smooth_loss.sum()\n",
    "        # when calculating cross-entropy, add the loss of other labels\n",
    "        eps_i = self.smoothing / lprobs.size(-1)\n",
    "        loss = (1.0 - self.smoothing) * nll_loss + eps_i * smooth_loss\n",
    "        return loss\n",
    "\n",
    "# generally, 0.1 is good enough\n",
    "criterion = LabelSmoothedCrossEntropyCriterion(\n",
    "    smoothing=0.1,\n",
    "    ignore_index=task.target_dictionary.pad(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizer: Adam + lr scheduling\n",
    "Inverse square root scheduling is important to the stability when training Transformer. It's later used on RNN as well.\n",
    "Update the learning rate according to the following equation. Linearly increase the first stage, then decay proportionally to the inverse square root of timestep.\n",
    "$$lrate = d_{\\text{model}}^{-0.5}\\cdot\\min({step\\_num}^{-0.5},{step\\_num}\\cdot{warmup\\_steps}^{-1.5})$$\n",
    "code [source](https://nlp.seas.harvard.edu/2018/04/03/attention.html)"
   ]
  },
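  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick way to read the equation above is to evaluate it at a few steps. The sketch below is a standalone function, separate from the `NoamOpt` wrapper in the next cell; the name `noam_rate` and the values `model_size=256`, `warmup=4000` are assumptions chosen for illustration only.\n",
    "\n",
    "```python\n",
    "def noam_rate(step, model_size, warmup, factor=1.0):\n",
    "    # learning rate from the equation above, scaled by an optional factor\n",
    "    if step == 0:\n",
    "        return 0.0\n",
    "    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)\n",
    "\n",
    "# hypothetical settings, for illustration only\n",
    "model_size, warmup = 256, 4000\n",
    "\n",
    "for step in (1, 1000, 4000, 16000):\n",
    "    print(step, noam_rate(step, model_size, warmup))\n",
    "\n",
    "# the two arguments of min() are equal at step == warmup, which is where the rate peaks\n",
    "peak = model_size ** -0.5 * warmup ** -0.5\n",
    "print('peak', peak)\n",
    "```\n",
    "\n",
    "Before `warmup` the second argument of `min` is the smaller one, so the rate grows linearly with the step. After `warmup` the first argument takes over and the rate decays as the inverse square root, so quadrupling the step count halves the rate."
   ]
  },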
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NoamOpt:\n",
    "    \"Optim wrapper that implements rate.\"\n",
    "    def __init__(self, model_size, factor, warmup, optimizer):\n",
    "        self.optimizer = optimizer\n",
    "        self._step = 0\n",
    "        self.warmup = warmup\n",
    "        self.factor = factor\n",
    "        self.model_size = model_size\n",
    "        self._rate = 0\n",
    "    \n",
    "    @property\n",
    "    def param_groups(self):\n",
    "        return self.optimizer.param_groups\n",
    "        \n",
    "    def multiply_grads(self, c):\n",
    "        \"\"\"Multiplies grads by a constant *c*.\"\"\"                \n",
    "        for group in self.param_groups:\n",
    "            for p in group['params']:\n",
    "                if p.grad is not None:\n",
    "                    p.grad.data.mul_(c)\n",
    "        \n",
    "    def step(self):\n",
    "        \"Update parameters and rate\"\n",
    "        self._step += 1\n",
    "        rate = self.rate()\n",
    "        for p in self.param_groups:\n",
    "            p['lr'] = rate\n",
    "        self._rate = rate\n",
    "        self.optimizer.step()\n",
    "        \n",
    "    def rate(self, step = None):\n",
    "        \"Implement `lrate` above\"\n",
    "        if step is None:\n",
    "            step = self._step\n",
    "        return 0 if not step else self.factor * \\\n",
    "            (self.model_size ** (-0.5) *\n",
    "            min(step ** (-0.5), step * self.warmup ** (-1.5)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scheduling Visualized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = NoamOpt(\n",
    "    model_size=arch_args.encoder_embed_dim, \n",
    "    factor=config.lr_factor, \n",
    "    warmup=config.lr_warmup, \n",
    "    optimizer=torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.0001))\n",
    "plt.plot(np.arange(1, 100000), [optimizer.rate(i) for i in range(1, 100000)])\n",
    "plt.legend([f\"{optimizer.model_size}:{optimizer.warmup}\"])\n",
    "None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Procedure"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairseq.data import iterators\n",
    "from torch.cuda.amp import GradScaler, autocast\n",
    "\n",
    "def train_one_epoch(epoch_itr, model, task, criterion, optimizer, accum_steps=1):\n",
    "    itr = epoch_itr.next_epoch_itr(shuffle=True)\n",
    "    itr = iterators.GroupedIterator(itr, accum_steps) # gradient accumulation: update every accum_steps samples\n",
    "    \n",
    "    stats = {\"loss\": []}\n",
    "    scaler = GradScaler() # automatic mixed precision (amp) \n",
    "    \n",
    "    model.train()\n",
    "    progress = tqdm.tqdm(itr, desc=f\"train epoch {epoch_itr.epoch}\", leave=False)\n",
    "    for samples in progress:\n",
    "        model.zero_grad()\n",
    "        accum_loss = 0\n",
    "        sample_size = 0\n",
    "        # gradient accumulation: update every accum_steps samples\n",
    "        for i, sample in enumerate(samples):\n",
    "            if i == 1:\n",
    "                # emptying the CUDA cache after the first step can reduce the chance of OOM\n",
    "                torch.cuda.empty_cache()\n",
    "\n",
    "            sample = utils.move_to_cuda(sample, device=device)\n",
    "            target = sample[\"target\"]\n",
    "            sample_size_i = sample[\"ntokens\"]\n",
    "            sample_size += sample_size_i\n",
    "            \n",
    "            # mixed precision training\n",
    "            with autocast():\n",
    "                net_output = model.forward(**sample[\"net_input\"])\n",
    "                lprobs = F.log_softmax(net_output[0], -1)            \n",
    "                loss = criterion(lprobs.view(-1, lprobs.size(-1)), target.view(-1))\n",
    "                \n",
    "                # logging\n",
    "                accum_loss += loss.item()\n",
    "                # back-prop\n",
    "                scaler.scale(loss).backward()                \n",
    "        \n",
    "        scaler.unscale_(optimizer)\n",
    "        optimizer.multiply_grads(1 / (sample_size or 1.0)) # (sample_size or 1.0) handles the case of a zero gradient\n",
    "        gnorm = nn.utils.clip_grad_norm_(model.parameters(), config.clip_norm) # grad norm clipping prevents gradient exploding\n",
    "        \n",
    "        scaler.step(optimizer)\n",
    "        scaler.update()\n",
    "        \n",
    "        # logging\n",
    "        loss_print = accum_loss/sample_size\n",
    "        stats[\"loss\"].append(loss_print)\n",
    "        progress.set_postfix(loss=loss_print)\n",
    "        if config.use_wandb:\n",
    "            wandb.log({\n",
    "                \"train/loss\": loss_print,\n",
    "                \"train/grad_norm\": gnorm.item(),\n",
    "                \"train/lr\": optimizer.rate(),\n",
    "                \"train/sample_size\": sample_size,\n",
    "            })\n",
    "        \n",
    "    loss_print = np.mean(stats[\"loss\"])\n",
    "    logger.info(f\"training loss: {loss_print:.4f}\")\n",
    "    return stats"
   ]
  },
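  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In `train_one_epoch` the criterion returns a loss summed over tokens, the gradients of `accum_steps` batches are accumulated, and `multiply_grads` divides once by the total token count. The toy sketch below checks on a one-parameter model that this gives the same gradient as averaging over one large batch. Everything in it (the squared-error loss, `grad_sum`, the data) is made up for the illustration, and samples stand in for tokens.\n",
    "\n",
    "```python\n",
    "def grad_sum(w, batch):\n",
    "    # gradient of the summed squared error sum((w*x - y)**2) with respect to w\n",
    "    return sum(2.0 * (w * x - y) * x for x, y in batch)\n",
    "\n",
    "w = 0.5\n",
    "# two micro-batches of different sizes (toy data)\n",
    "micro_batches = [[(1.0, 2.0), (2.0, 3.0), (3.0, 7.0)], [(4.0, 8.0)]]\n",
    "\n",
    "# accumulate summed gradients, count the samples, normalize once at the end\n",
    "accum_grad = 0.0\n",
    "sample_size = 0\n",
    "for batch in micro_batches:\n",
    "    accum_grad += grad_sum(w, batch)\n",
    "    sample_size += len(batch)\n",
    "accum_grad *= 1 / (sample_size or 1.0)\n",
    "\n",
    "# reference: gradient of the mean loss over one big batch\n",
    "big_batch = [pair for batch in micro_batches for pair in batch]\n",
    "full_grad = grad_sum(w, big_batch) / len(big_batch)\n",
    "\n",
    "print(accum_grad, full_grad)\n",
    "```\n",
    "\n",
    "Normalizing once by the total count, instead of averaging each micro-batch separately, keeps every sample weighted equally even when the micro-batches differ in size."
   ]
  },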
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validation & Inference\n",
    "To prevent overfitting, validation is required every epoch to validate the performance on unseen data.\n",
    "- the procedure is essensially same as training, with the addition of inference step\n",
    "- after validation we can save the model weights\n",
    "\n",
    "Validation loss alone cannot describe the actual performance of the model\n",
    "- Directly produce translation hypotheses based on current model, then calculate BLEU with the reference translation\n",
    "- We can also manually examine the hypotheses' quality\n",
    "- We use fairseq's sequence generator for beam search to generate translation hypotheses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fairseq's beam search generator\n",
    "# given model and input seqeunce, produce translation hypotheses by beam search\n",
    "sequence_generator = task.build_generator([model], config)\n",
    "\n",
    "def decode(toks, dictionary):\n",
    "    # convert from Tensor to human readable sentence\n",
    "    s = dictionary.string(\n",
    "        toks.int().cpu(),\n",
    "        config.post_process,\n",
    "    )\n",
    "    return s if s else \"<unk>\"\n",
    "\n",
    "def inference_step(sample, model):\n",
    "    gen_out = sequence_generator.generate([model], sample)\n",
    "    srcs = []\n",
    "    hyps = []\n",
    "    refs = []\n",
    "    for i in range(len(gen_out)):\n",
    "        # for each sample, collect the input, hypothesis and reference, later be used to calculate BLEU\n",
    "        srcs.append(decode(\n",
    "            utils.strip_pad(sample[\"net_input\"][\"src_tokens\"][i], task.source_dictionary.pad()), \n",
    "            task.source_dictionary,\n",
    "        ))\n",
    "        hyps.append(decode(\n",
    "            gen_out[i][0][\"tokens\"], # 0 indicates using the top hypothesis in beam\n",
    "            task.target_dictionary,\n",
    "        ))\n",
    "        refs.append(decode(\n",
    "            utils.strip_pad(sample[\"target\"][i], task.target_dictionary.pad()), \n",
    "            task.target_dictionary,\n",
    "        ))\n",
    "    return srcs, hyps, refs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "import sacrebleu\n",
    "\n",
    "def validate(model, task, criterion, log_to_wandb=True):\n",
    "    logger.info('begin validation')\n",
    "    itr = load_data_iterator(task, \"valid\", 1, config.max_tokens, config.num_workers).next_epoch_itr(shuffle=False)\n",
    "    \n",
    "    stats = {\"loss\":[], \"bleu\": 0, \"srcs\":[], \"hyps\":[], \"refs\":[]}\n",
    "    srcs = []\n",
    "    hyps = []\n",
    "    refs = []\n",
    "    \n",
    "    model.eval()\n",
    "    progress = tqdm.tqdm(itr, desc=f\"validation\", leave=False)\n",
    "    with torch.no_grad():\n",
    "        for i, sample in enumerate(progress):\n",
    "            # validation loss\n",
    "            sample = utils.move_to_cuda(sample, device=device)\n",
    "            net_output = model.forward(**sample[\"net_input\"])\n",
    "\n",
    "            lprobs = F.log_softmax(net_output[0], -1)\n",
    "            target = sample[\"target\"]\n",
    "            sample_size = sample[\"ntokens\"]\n",
    "            loss = criterion(lprobs.view(-1, lprobs.size(-1)), target.view(-1)) / sample_size\n",
    "            progress.set_postfix(valid_loss=loss.item())\n",
    "            stats[\"loss\"].append(loss)\n",
    "            \n",
    "            # do inference\n",
    "            s, h, r = inference_step(sample, model)\n",
    "            srcs.extend(s)\n",
    "            hyps.extend(h)\n",
    "            refs.extend(r)\n",
    "            \n",
    "    tok = 'zh' if task.cfg.target_lang == 'zh' else '13a'\n",
    "    stats[\"loss\"] = torch.stack(stats[\"loss\"]).mean().item()\n",
    "    stats[\"bleu\"] = sacrebleu.corpus_bleu(hyps, [refs], tokenize=tok) # 計算BLEU score\n",
    "    stats[\"srcs\"] = srcs\n",
    "    stats[\"hyps\"] = hyps\n",
    "    stats[\"refs\"] = refs\n",
    "    \n",
    "    if config.use_wandb and log_to_wandb:\n",
    "        wandb.log({\n",
    "            \"valid/loss\": stats[\"loss\"],\n",
    "            \"valid/bleu\": stats[\"bleu\"].score,\n",
    "        }, commit=False)\n",
    "    \n",
    "    showid = np.random.randint(len(hyps))\n",
    "    logger.info(\"example source: \" + srcs[showid])\n",
    "    logger.info(\"example hypothesis: \" + hyps[showid])\n",
    "    logger.info(\"example reference: \" + refs[showid])\n",
    "    \n",
    "    # show bleu results\n",
    "    logger.info(f\"validation loss:\\t{stats['loss']:.4f}\")\n",
    "    logger.info(stats[\"bleu\"].format())\n",
    "    return stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save and Load Model Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def validate_and_save(model, task, criterion, optimizer, epoch, save=True):   \n",
    "    stats = validate(model, task, criterion)\n",
    "    bleu = stats['bleu']\n",
    "    loss = stats['loss']\n",
    "    if save:\n",
    "        # save epoch checkpoints\n",
    "        savedir = Path(config.savedir).absolute()\n",
    "        savedir.mkdir(parents=True, exist_ok=True)\n",
    "        \n",
    "        check = {\n",
    "            \"model\": model.state_dict(),\n",
    "            \"stats\": {\"bleu\": bleu.score, \"loss\": loss},\n",
    "            \"optim\": {\"step\": optimizer._step}\n",
    "        }\n",
    "        torch.save(check, savedir/f\"checkpoint{epoch}.pt\")\n",
    "        shutil.copy(savedir/f\"checkpoint{epoch}.pt\", savedir/f\"checkpoint_last.pt\")\n",
    "        logger.info(f\"saved epoch checkpoint: {savedir}/checkpoint{epoch}.pt\")\n",
    "    \n",
    "        # save epoch samples\n",
    "        with open(savedir/f\"samples{epoch}.{config.source_lang}-{config.target_lang}.txt\", \"w\") as f:\n",
    "            for s, h in zip(stats[\"srcs\"], stats[\"hyps\"]):\n",
    "                f.write(f\"{s}\\t{h}\\n\")\n",
    "\n",
    "        # get best valid bleu    \n",
    "        if getattr(validate_and_save, \"best_bleu\", 0) < bleu.score:\n",
    "            validate_and_save.best_bleu = bleu.score\n",
    "            torch.save(check, savedir/f\"checkpoint_best.pt\")\n",
    "            \n",
    "        del_file = savedir / f\"checkpoint{epoch - config.keep_last_epochs}.pt\"\n",
    "        if del_file.exists():\n",
    "            del_file.unlink()\n",
    "    \n",
    "    return stats\n",
    "\n",
    "def try_load_checkpoint(model, optimizer=None, name=None):\n",
    "    name = name if name else \"checkpoint_last.pt\"\n",
    "    checkpath = Path(config.savedir)/name\n",
    "    if checkpath.exists():\n",
    "        check = torch.load(checkpath)\n",
    "        model.load_state_dict(check[\"model\"])\n",
    "        stats = check[\"stats\"]\n",
    "        step = \"unknown\"\n",
    "        if optimizer != None:\n",
    "            optimizer._step = step = check[\"optim\"][\"step\"]\n",
    "        logger.info(f\"loaded checkpoint {checkpath}: step={step} loss={stats['loss']} bleu={stats['bleu']}\")\n",
    "    else:\n",
    "        logger.info(f\"no checkpoints found at {checkpath}!\")"
   ]
  },
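  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rotation in `validate_and_save` keeps only the most recent `config.keep_last_epochs` epoch checkpoints: after each save it deletes `checkpoint{epoch - keep_last_epochs}.pt` if that file is present. A minimal sketch of that bookkeeping, using empty files in a temporary directory in place of real checkpoints (the helper name `save_and_rotate` is hypothetical):\n",
    "\n",
    "```python\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "def save_and_rotate(savedir, epoch, keep_last_epochs):\n",
    "    # stand-in for torch.save: just create an empty file\n",
    "    (savedir / f'checkpoint{epoch}.pt').touch()\n",
    "    # drop the checkpoint that has fallen out of the window\n",
    "    old = savedir / f'checkpoint{epoch - keep_last_epochs}.pt'\n",
    "    if old.exists():\n",
    "        old.unlink()\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    savedir = Path(tmp)\n",
    "    for epoch in range(1, 8):\n",
    "        save_and_rotate(savedir, epoch, keep_last_epochs=5)\n",
    "    remaining = sorted(p.name for p in savedir.glob('checkpoint*.pt'))\n",
    "\n",
    "print(remaining)\n",
    "```\n",
    "\n",
    "With a window of 5, five epoch files remain after seven epochs, which matches the `--num-epoch-checkpoints 5` setting used in the averaging step."
   ]
  },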
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main\n",
    "## Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device=device)\n",
    "criterion = criterion.to(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(\"task: {}\".format(task.__class__.__name__))\n",
    "logger.info(\"encoder: {}\".format(model.encoder.__class__.__name__))\n",
    "logger.info(\"decoder: {}\".format(model.decoder.__class__.__name__))\n",
    "logger.info(\"criterion: {}\".format(criterion.__class__.__name__))\n",
    "logger.info(\"optimizer: {}\".format(optimizer.__class__.__name__))\n",
    "logger.info(\n",
    "    \"num. model params: {:,} (num. trained: {:,})\".format(\n",
    "        sum(p.numel() for p in model.parameters()),\n",
    "        sum(p.numel() for p in model.parameters() if p.requires_grad),\n",
    "    )\n",
    ")\n",
    "logger.info(f\"max tokens per batch = {config.max_tokens}, accumulate steps = {config.accum_steps}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "epoch_itr = load_data_iterator(task, \"train\", config.start_epoch, config.max_tokens, config.num_workers)\n",
    "try_load_checkpoint(model, optimizer, name=config.resume)\n",
    "while epoch_itr.next_epoch_idx <= config.max_epoch:\n",
    "    # train for one epoch\n",
    "    train_one_epoch(epoch_itr, model, task, criterion, optimizer, config.accum_steps)\n",
    "    stats = validate_and_save(model, task, criterion, optimizer, epoch=epoch_itr.epoch)\n",
    "    logger.info(\"end of epoch {}\".format(epoch_itr.epoch))    \n",
    "    epoch_itr = load_data_iterator(task, \"train\", epoch_itr.next_epoch_idx, config.max_tokens, config.num_workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Submission"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# averaging a few checkpoints can have a similar effect to ensemble\n",
    "checkdir=config.savedir\n",
    "!python ./fairseq/scripts/average_checkpoints.py \\\n",
    "--inputs {checkdir} \\\n",
    "--num-epoch-checkpoints 5 \\\n",
    "--output {checkdir}/avg_last_5_checkpoint.pt"
   ]
  },
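  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Checkpoint averaging takes the element-wise mean of every parameter across several saved models. The sketch below shows the idea on plain Python lists standing in for tensors; `average_state_dicts` is a hypothetical helper for illustration, not the implementation in fairseq's `average_checkpoints.py`.\n",
    "\n",
    "```python\n",
    "def average_state_dicts(state_dicts):\n",
    "    # every checkpoint is assumed to hold the same parameter names and shapes\n",
    "    n = len(state_dicts)\n",
    "    averaged = {}\n",
    "    for name in state_dicts[0]:\n",
    "        # group the i-th element of this parameter across all checkpoints\n",
    "        columns = zip(*(sd[name] for sd in state_dicts))\n",
    "        averaged[name] = [sum(values) / n for values in columns]\n",
    "    return averaged\n",
    "\n",
    "# two toy checkpoints with made-up values\n",
    "checkpoints = [\n",
    "    {'w': [1.0, 2.0], 'b': [0.0]},\n",
    "    {'w': [3.0, 4.0], 'b': [1.0]},\n",
    "]\n",
    "avg = average_state_dicts(checkpoints)\n",
    "print(avg)\n",
    "```"
   ]
  },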
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Confirm model weights used to generate submission"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoint_last.pt : latest epoch\n",
    "# checkpoint_best.pt : highest validation bleu\n",
    "# avg_last_5_checkpoint.pt:　the average of last 5 epochs\n",
    "try_load_checkpoint(model, name=\"avg_last_5_checkpoint.pt\")\n",
    "validate(model, task, criterion, log_to_wandb=False)\n",
    "None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_prediction(model, task, split=\"test\", outfile=\"./prediction.txt\"):    \n",
    "    task.load_dataset(split=split, epoch=1)\n",
    "    itr = load_data_iterator(task, split, 1, config.max_tokens, config.num_workers).next_epoch_itr(shuffle=False)\n",
    "    \n",
    "    idxs = []\n",
    "    hyps = []\n",
    "\n",
    "    model.eval()\n",
    "    progress = tqdm.tqdm(itr, desc=f\"prediction\")\n",
    "    with torch.no_grad():\n",
    "        for i, sample in enumerate(progress):\n",
    "            # validation loss\n",
    "            sample = utils.move_to_cuda(sample, device=device)\n",
    "\n",
    "            # do inference\n",
    "            s, h, r = inference_step(sample, model)\n",
    "            \n",
    "            hyps.extend(h)\n",
    "            idxs.extend(list(sample['id']))\n",
    "            \n",
    "    # sort based on the order before preprocess\n",
    "    hyps = [x for _,x in sorted(zip(idxs,hyps))]\n",
    "    \n",
    "    with open(outfile, \"w\") as f:\n",
    "        for h in hyps:\n",
    "            f.write(h+\"\\n\")"
   ]
  },
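  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because batches are not guaranteed to follow the order of the input file, `generate_prediction` collects `sample['id']` alongside each hypothesis and sorts on it before writing. A small sketch of that re-ordering with made-up ids and hypotheses; it sorts on the id alone (an assumption that differs slightly from the tuple sort above) so the hypothesis text never takes part in the comparison:\n",
    "\n",
    "```python\n",
    "from operator import itemgetter\n",
    "\n",
    "idxs = [2, 0, 3, 1]            # made-up sample ids, in batch order\n",
    "hyps = ['c', 'a', 'd', 'b']    # made-up hypotheses, in the same order\n",
    "\n",
    "# sort (id, hypothesis) pairs by id only, then keep the hypotheses\n",
    "restored = [h for _, h in sorted(zip(idxs, hyps), key=itemgetter(0))]\n",
    "print(restored)\n",
    "```"
   ]
  },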
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generate_prediction(model, task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Back-translation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a backward translation model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Switch the source_lang and target_lang in **config** \n",
    "2. Change the savedir in **config** (eg. \"./checkpoints/transformer-back\")\n",
    "3. Train model"
   ]
  },
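  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The steps above amount to a small change to `config` before re-running the training cells. A sketch, assuming `config` is an `argparse.Namespace` with the `source_lang`, `target_lang` and `savedir` fields used elsewhere in this notebook (the starting values below are placeholders, not the notebook's actual settings):\n",
    "\n",
    "```python\n",
    "from argparse import Namespace\n",
    "\n",
    "# placeholder values; in the notebook, config is defined in an earlier cell\n",
    "config = Namespace(source_lang='en', target_lang='zh',\n",
    "                   savedir='./checkpoints/transformer')\n",
    "\n",
    "# 1. swap the translation direction so the model learns zh -> en\n",
    "config.source_lang, config.target_lang = config.target_lang, config.source_lang\n",
    "# 2. write checkpoints to a separate directory so the forward model is not overwritten\n",
    "config.savedir = './checkpoints/transformer-back'\n",
    "\n",
    "print(config)\n",
    "```\n",
    "\n",
    "Step 3 is then a matter of re-running the model construction and training cells with this modified `config`."
   ]
  },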
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate synthetic data with backward model "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download monolingual data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mono_dataset_name = 'mono'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mono_prefix = Path(data_dir).absolute() / mono_dataset_name\n",
    "mono_prefix.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "urls = (\n",
    "    '\"https://onedrive.live.com/download?cid=3E549F3B24B238B4&resid=3E549F3B24B238B4%214986&authkey=AANUKbGfZx0kM80\"',\n",
    "# # If the above links die, use the following instead. \n",
    "#     \"https://www.csie.ntu.edu.tw/~r09922057/ML2021-hw5/ted_zh_corpus.deduped.gz\",\n",
    "# # If the above links die, use the following instead. \n",
    "#     \"https://mega.nz/#!vMNnDShR!4eHDxzlpzIpdpeQTD-htatU_C7QwcBTwGDaSeBqH534\",\n",
    ")\n",
    "file_names = (\n",
    "    'ted_zh_corpus.deduped.gz',\n",
    ")\n",
    "\n",
    "for u, f in zip(urls, file_names):\n",
    "    path = mono_prefix/f\n",
    "    if not path.exists():\n",
    "        if 'mega' in u:\n",
    "            !megadl {u} --path {path}\n",
    "        else:\n",
    "            !wget {u} -O {path}\n",
    "    else:\n",
    "        print(f'{f} is exist, skip downloading')\n",
    "    if path.suffix == \".tgz\":\n",
    "        !tar -xvf {path} -C {prefix}\n",
    "    elif path.suffix == \".zip\":\n",
    "        !unzip -o {path} -d {prefix}\n",
    "    elif path.suffix == \".gz\":\n",
    "        !gzip -fkd {path}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TODO: clean corpus\n",
    "\n",
    "1. remove sentences that are too long or too short\n",
    "2. unify punctuation\n",
    "\n",
    "hint: you can use clean_s() defined above to do this"
   ]
  },
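  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible shape for this cleaning step, offered as a sketch rather than a reference solution: it converts full-width punctuation to half-width and drops lines whose character count falls outside a range. The names `to_halfwidth` and `clean_mono_line` and the length thresholds are assumptions; `clean_s()` from earlier in the notebook may cover more cases.\n",
    "\n",
    "```python\n",
    "def to_halfwidth(text):\n",
    "    # map full-width ASCII variants (U+FF01..U+FF5E) and the ideographic\n",
    "    # space (U+3000) to their half-width counterparts\n",
    "    out = []\n",
    "    for ch in text:\n",
    "        code = ord(ch)\n",
    "        if code == 0x3000:\n",
    "            code = 0x20\n",
    "        elif 0xFF01 <= code <= 0xFF5E:\n",
    "            code -= 0xFEE0\n",
    "        out.append(chr(code))\n",
    "    return ''.join(out)\n",
    "\n",
    "def clean_mono_line(line, min_len=1, max_len=400):\n",
    "    # returns the cleaned line, or None if the line should be dropped\n",
    "    line = to_halfwidth(line).strip()\n",
    "    if not (min_len <= len(line) <= max_len):\n",
    "        return None\n",
    "    return line\n",
    "\n",
    "# made-up input: one normal line, one blank line, one overly long line\n",
    "lines = ['你好，世界！', '   ', '很長' * 300]\n",
    "kept = [c for c in map(clean_mono_line, lines) if c is not None]\n",
    "print(kept)\n",
    "```\n",
    "\n",
    "The thresholds count characters, which is a reasonable proxy for Chinese text; a different measure may suit other languages better."
   ]
  },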
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TODO: Subword Units\n",
    "\n",
    "Use the spm model of the backward model to tokenize the data into subword units\n",
    "\n",
    "hint: spm model is located at DATA/raw-data/\\[dataset\\]/spm\\[vocab_num\\].model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Binarize\n",
    "\n",
    "use fairseq to binarize data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "binpath = Path('./DATA/data-bin', mono_dataset_name)\n",
    "src_dict_file = './DATA/data-bin/ted2020/dict.en.txt'\n",
    "tgt_dict_file = src_dict_file\n",
    "monopref = str(mono_prefix/\"mono.tok\") # whatever filepath you get after applying subword tokenization\n",
    "if binpath.exists():\n",
    "    print(binpath, \"exists, will not overwrite!\")\n",
    "else:\n",
    "    !python -m fairseq_cli.preprocess\\\n",
    "        --source-lang 'zh'\\\n",
    "        --target-lang 'en'\\\n",
    "        --trainpref {monopref}\\\n",
    "        --destdir {binpath}\\\n",
    "        --srcdict {src_dict_file}\\\n",
    "        --tgtdict {tgt_dict_file}\\\n",
    "        --workers 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TODO: Generate synthetic data with backward model\n",
    "\n",
    "Add binarized monolingual data to the original data directory, and name it with \"split_name\"\n",
    "\n",
    "ex. ./DATA/data-bin/ted2020/\\[split_name\\].zh-en.\\[\"en\", \"zh\"\\].\\[\"bin\", \"idx\"\\]\n",
    "\n",
    "then you can use 'generate_prediction(model, task, split=\"split_name\")' to generate translation prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add binarized monolingual data to the original data directory, and name it with \"split_name\"\n",
    "# ex. ./DATA/data-bin/ted2020/\\[split_name\\].zh-en.\\[\"en\", \"zh\"\\].\\[\"bin\", \"idx\"\\]\n",
    "!cp ./DATA/data-bin/mono/train.zh-en.zh.bin ./DATA/data-bin/ted2020/mono.zh-en.zh.bin\n",
    "!cp ./DATA/data-bin/mono/train.zh-en.zh.idx ./DATA/data-bin/ted2020/mono.zh-en.zh.idx\n",
    "!cp ./DATA/data-bin/mono/train.zh-en.en.bin ./DATA/data-bin/ted2020/mono.zh-en.en.bin\n",
    "!cp ./DATA/data-bin/mono/train.zh-en.en.idx ./DATA/data-bin/ted2020/mono.zh-en.en.idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hint: do prediction on split='mono' to create prediction_file\n",
    "# generate_prediction( ... ,split=... ,outfile=... )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TODO: Create new dataset\n",
    "\n",
    "1. Combine the prediction data with monolingual data\n",
    "2. Use the original spm model to tokenize data into Subword Units\n",
    "3. Binarize data with fairseq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine prediction_file (.en) and mono.zh (.zh) into a new dataset.\n",
    "# \n",
    "# hint: tokenize prediction_file with the spm model\n",
    "# spm_model.encode(line, out_type=str)\n",
    "# output: ./DATA/rawdata/mono/mono.tok.en & mono.tok.zh\n",
    "#\n",
    "# hint: use fairseq to binarize these two files again\n",
    "# binpath = Path('./DATA/data-bin/synthetic')\n",
    "# src_dict_file = './DATA/data-bin/ted2020/dict.en.txt'\n",
    "# tgt_dict_file = src_dict_file\n",
    "# monopref = ./DATA/rawdata/mono/mono.tok # or whatever path after applying subword tokenization, w/o the suffix (.zh/.en)\n",
    "# if binpath.exists():\n",
    "#     print(binpath, \"exists, will not overwrite!\")\n",
    "# else:\n",
    "#     !python -m fairseq_cli.preprocess\\\n",
    "#         --source-lang 'zh'\\\n",
    "#         --target-lang 'en'\\\n",
    "#         --trainpref {monopref}\\\n",
    "#         --destdir {binpath}\\\n",
    "#         --srcdict {src_dict_file}\\\n",
    "#         --tgtdict {tgt_dict_file}\\\n",
    "#         --workers 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a new dataset from all the files prepared above\n",
    "!cp -r ./DATA/data-bin/ted2020/ ./DATA/data-bin/ted2020_with_mono/\n",
    "\n",
    "!cp ./DATA/data-bin/synthetic/train.zh-en.zh.bin ./DATA/data-bin/ted2020_with_mono/train1.en-zh.zh.bin\n",
    "!cp ./DATA/data-bin/synthetic/train.zh-en.zh.idx ./DATA/data-bin/ted2020_with_mono/train1.en-zh.zh.idx\n",
    "!cp ./DATA/data-bin/synthetic/train.zh-en.en.bin ./DATA/data-bin/ted2020_with_mono/train1.en-zh.en.bin\n",
    "!cp ./DATA/data-bin/synthetic/train.zh-en.en.idx ./DATA/data-bin/ted2020_with_mono/train1.en-zh.en.idx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Created new dataset \"ted2020_with_mono\"\n",
    "\n",
    "1. Change the datadir in **config** (\"./DATA/data-bin/ted2020_with_mono\")\n",
    "2. Switch back the source_lang and target_lang in **config** (\"en\", \"zh\")\n",
    "2. Change the savedir in **config** (eg. \"./checkpoints/transformer-bt\")\n",
    "3. Train model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# References"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. <a name=ott2019fairseq></a>Ott, M., Edunov, S., Baevski, A., Fan, A., Gross, S., Ng, N., ... & Auli, M. (2019, June). fairseq: A Fast, Extensible Toolkit for Sequence Modeling. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics (Demonstrations) (pp. 48-53).\n",
    "2. <a name=vaswani2017></a>Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017, December). Attention is all you need. In Proceedings of the 31st International Conference on Neural Information Processing Systems (pp. 6000-6010).\n",
    "3. <a name=reimers-2020-multilingual-sentence-bert></a>Reimers, N., & Gurevych, I. (2020, November). Making Monolingual Sentence Embeddings Multilingual Using Knowledge Distillation. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP) (pp. 4512-4525).\n",
    "4. <a name=tiedemann2012parallel></a>Tiedemann, J. (2012, May). Parallel Data, Tools and Interfaces in OPUS. In Lrec (Vol. 2012, pp. 2214-2218).\n",
    "5. <a name=kudo-richardson-2018-sentencepiece></a>Kudo, T., & Richardson, J. (2018, November). SentencePiece: A simple and language independent subword tokenizer and detokenizer for Neural Text Processing. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing: System Demonstrations (pp. 66-71).\n",
    "6. <a name=sennrich-etal-2016-improving></a>Sennrich, R., Haddow, B., & Birch, A. (2016, August). Improving Neural Machine Translation Models with Monolingual Data. In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers) (pp. 86-96).\n",
    "7. <a name=edunov-etal-2018-understanding></a>Edunov, S., Ott, M., Auli, M., & Grangier, D. (2018). Understanding Back-Translation at Scale. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (pp. 489-500).\n",
    "8. https://github.com/ajinkyakulkarni14/TED-Multilingual-Parallel-Corpus\n",
    "9. https://ithelp.ithome.com.tw/articles/10233122\n",
    "10. https://nlp.seas.harvard.edu/2018/04/03/attention.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
